PlutoCon 2021 WYSIWYR Demo (MNIST)
Author: Connor Burns
In this notebook we will load a pretrained model for classifying MNIST handwritten digits from 28x28 greyscale images. However, this notebook is less about the model and more about interacting with it via "what you see is what you REST" features.
Loading Data
To start off, we will download the MNIST dataset using the MLDatasets package.
MNIST.download(; i_accept_the_terms_of_use=true);
Next we load a pretrained model that has been serialized with Julia's native serialization library. The model is made up of three convolutional layers, three max pooling layers, and one dense layer.
Failed to connect to cot.llc port 80: Connection timed out while requesting http://cot.llc/mnist_conv
- (::Downloads.var"#9#18"{IOStream, Base.DevNull, Nothing, Vector{Pair{String, String}}, Float64, Nothing, Bool, Bool, String, Int64, Bool, Bool})(::Downloads.Curl.Easy)@Downloads.jl:356
- with_handle(::Downloads.var"#9#18"{IOStream, Base.DevNull, Nothing, Vector{Pair{String, String}}, Float64, Nothing, Bool, Bool, String, Int64, Bool, Bool}, ::Downloads.Curl.Easy)@Curl.jl:60
- #8@Downloads.jl:298[inlined]
- arg_write(::Downloads.var"#8#17"{Base.DevNull, Nothing, Vector{Pair{String, String}}, Float64, Nothing, Bool, Bool, String, Int64, Bool, Bool}, ::IOStream)@ArgTools.jl:112
- #7@Downloads.jl:297[inlined]
- arg_read@ArgTools.jl:61[inlined]
- var"#request#5"(::Nothing, ::IOStream, ::Nothing, ::Vector{Pair{String, String}}, ::Float64, ::Nothing, ::Bool, ::Bool, ::Nothing, ::typeof(Downloads.request), ::String)@Downloads.jl:296
- (::Downloads.var"#3#4"{Nothing, Vector{Pair{String, String}}, Float64, Nothing, Bool, Nothing, String})(::IOStream)@Downloads.jl:209
- arg_write(::Downloads.var"#3#4"{Nothing, Vector{Pair{String, String}}, Float64, Nothing, Bool, Nothing, String}, ::Nothing)@ArgTools.jl:101
- #download#2@Downloads.jl:208[inlined]
- download(::String, ::Nothing)@Downloads.jl:208
- #invokelatest#2@essentials.jl:708[inlined]
- invokelatest@essentials.jl:706[inlined]
- do_download@download.jl:33[inlined]
- download@download.jl:29[inlined]
- top-level scope@Local: 1[inlined]
serialized_model_path = download("http://cot.llc/mnist_conv")
UndefVarError: serialized_model_path not defined
- top-level scope@Local: 1
model = open(io -> deserialize(io), serialized_model_path)
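The load step above relies on Julia's Serialization standard library. As a minimal sketch of the same round trip, the snippet below writes and reads back a hypothetical stand-in object (toy_model is an invented name, not the notebook's real model) using a temporary file instead of the download. Note that deserializing a real model also requires the packages that define its types (here presumably Flux) to be loaded first, and the format is not guaranteed to be portable across Julia versions.

```julia
using Serialization

# Hypothetical stand-in for a trained model: any plain Julia value can be serialized.
toy_model = (weights = [0.1 0.2; 0.3 0.4], bias = [0.0, 1.0])

# Write the object to a temporary file, as a training notebook might do.
path = tempname()
open(io -> serialize(io, toy_model), path, "w")

# Read it back with the same pattern used in the cell above.
restored = open(io -> deserialize(io), path)
```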
To test our model we will load only the test data. The model was trained in another notebook on the training data from MNIST.traindata().
test_x, test_y = MNIST.testdata();
test_x shape: (28, 28, 10000), test_y shape: (10000,)
Testing the model (and building the API too!)
First we assign a small slice of the test data to the variable input_images.
Start Index:
End Index:
1
safe_start_index = max(start_index |> default(1), 1)
10
safe_end_index = min(end_index |> default(10), length(test_y))
1:10
input_images_slice = min(safe_start_index, safe_end_index):max(safe_start_index, safe_end_index)
input_images = Flux.unsqueeze(test_x, 3)[:, :, :, input_images_slice];
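The cell above adds a singleton channel dimension and then slices along the batch dimension. A rough plain-Julia equivalent, assuming random data in place of test_x (fake_test_x is a made-up name) and using reshape instead of Flux.unsqueeze, might look like this:

```julia
# Random stand-in for test_x: 28x28 images stacked along the third dimension.
fake_test_x = rand(Float32, 28, 28, 100)

# Insert a singleton channel dimension at position 3, giving
# (width, height, channel, batch).
with_channel = reshape(fake_test_x, 28, 28, 1, :)

# Select images 1 through 10 along the batch dimension.
batch = with_channel[:, :, :, 1:10]

size(batch)  # (28, 28, 1, 10)
```

Convolutional layers in Flux expect this four-dimensional layout, which is why the channel dimension is added even though the images are greyscale.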
For example, the first element in the sample is a 7.
display_digit(input_images[:, :, 1, 1])
Passing input_images through the model loaded earlier yields a 10xN matrix, where N is the number of input images (10 for the default slice 1:10). Each column corresponds to one input image, and each row holds the model's score for one digit class, so a high value in the first row means high confidence that the image shows a 0.
For the first image, the highest value by far is in the 8th row, which corresponds to the model predicting a 7.
UndefVarError: model not defined
- top-level scope@Local: 1
predictions = model(input_images)
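To make the row and column convention concrete, here is a small sketch that reads one column of a score matrix by hand. The numbers are invented for illustration and are not output from the real model.

```julia
# Hypothetical scores for two images; rows 1..10 stand for digits 0..9.
scores_a = [0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.01, 0.90, 0.01, 0.01]
scores_b = [0.01, 0.01, 0.90, 0.02, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
fake_predictions = hcat(scores_a, scores_b)   # 10x2: one column per image

# findmax returns the largest value in the first column and its 1-based row.
best_score, row = findmax(fake_predictions[:, 1])

# Rows are 1-based while digits start at 0, so subtract one.
digit = row - 1   # 7
```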
The last step is to convert these predictions into digit labels, then compare them to the true labels.
UndefVarError: predictions not defined
- top-level scope@Local: 1
output_labels = Flux.onecold(predictions, 0:9)
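For readers unfamiliar with Flux.onecold, the following is a plain-Julia approximation of what the call above does for a matrix input: pick the row with the highest score in each column and map it to a label. The function name onecold_sketch and the score values are assumptions made for illustration, not part of Flux.

```julia
# Plain-Julia stand-in for Flux.onecold(predictions, 0:9) on a matrix.
function onecold_sketch(predictions::AbstractMatrix, labels)
    # For each column (one image), take the label at the highest-scoring row.
    return [labels[argmax(col)] for col in eachcol(predictions)]
end

# Invented scores for three images.
fake_predictions = zeros(Float32, 10, 3)
fake_predictions[8, 1] = 0.9f0   # row 8 -> digit 7
fake_predictions[3, 2] = 0.8f0   # row 3 -> digit 2
fake_predictions[2, 3] = 0.7f0   # row 2 -> digit 1

sketch_labels = onecold_sketch(fake_predictions, 0:9)   # [7, 2, 1]
```

Passing 0:9 as the label set handles the offset between Julia's 1-based row indices and the digits 0 through 9.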
7
2
1
0
4
1
4
9
5
9
test_labels = test_y[input_images_slice]
Finally, we can measure the accuracy of the model by comparing our predictions to the actual labels element by element and taking the mean of the matches.
UndefVarError: output_labels not defined
- top-level scope@Local: 1
Int.(output_labels .== test_labels)
UndefVarError: output_labels not defined
- top-level scope@Local: 1
accuracy = mean(output_labels .== test_labels)
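As a worked example of the accuracy calculation, the sketch below uses the ten true labels shown above together with a hypothetical set of predictions in which one digit is wrong. The predicted values are invented, since the model could not be loaded in this run.

```julia
using Statistics

# True labels for the default slice 1:10, as shown in the notebook output.
test_labels = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]

# Hypothetical predictions: the ninth digit is deliberately wrong.
output_labels = [7, 2, 1, 0, 4, 1, 4, 9, 6, 9]

# Element-wise comparison gives a Boolean vector of matches.
matches = output_labels .== test_labels

# The mean of a Boolean vector is the fraction of true entries.
accuracy = mean(matches)   # 0.9
```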
Helpers
default (generic function with 1 method)
# Returns a function that substitutes x when its argument is nothing or NaN.
function default(x)
    return y -> (isnothing(y) || isnan(y)) ? x : y
end
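A brief usage sketch for default, with the helper repeated so the snippet runs on its own. The input values are illustrative; presumably the bound index variables can be nothing or NaN before a value has been entered, which is the case the helper guards against.

```julia
# Same helper as in the notebook, repeated so this sketch is self-contained.
default(x) = y -> (isnothing(y) || isnan(y)) ? x : y

from_nothing = nothing |> default(1)   # 1, the fallback is used
from_nan = NaN |> default(1)           # 1, the fallback is used
from_value = 5 |> default(1)           # 5, the supplied value passes through

# Combined with max, as in the notebook, to keep a start index at 1 or above.
safe_start = max(-3 |> default(1), 1)  # 1
```

Because the check calls isnan, the helper only accepts nothing or numeric input; a string, for instance, would raise a MethodError.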
display_digit (generic function with 1 method)
# Swaps the two image axes, then renders each value as a greyscale pixel.
function display_digit(img)
    Gray.(permutedims(img, (2, 1)))
end
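The permutedims call in display_digit swaps the two image axes, most likely because the raw MNIST arrays are stored with width as the first dimension and would otherwise display transposed. A tiny sketch on an ordinary matrix shows the effect, leaving out the Gray conversion so that no extra package is needed:

```julia
# A 2x3 matrix standing in for an image.
img = [1 2 3;
       4 5 6]

# Swap dimensions 1 and 2; the result is 3x2.
flipped = permutedims(img, (2, 1))   # [1 4; 2 5; 3 6]
```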